package edu.northwestern.cbits.purple_robot_manager.tests.models; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; import junit.framework.Assert; import android.content.Context; import edu.northwestern.cbits.purple_robot_manager.R; import edu.northwestern.cbits.purple_robot_manager.models.Model; import edu.northwestern.cbits.purple_robot_manager.models.ModelManager; import edu.northwestern.cbits.purple_robot_manager.models.trees.LeafNode; import edu.northwestern.cbits.purple_robot_manager.models.trees.TreeNode; import edu.northwestern.cbits.purple_robot_manager.models.trees.TreeNode.TreeNodeException; import edu.northwestern.cbits.purple_robot_manager.models.trees.parsers.TreeNodeParser; import edu.northwestern.cbits.purple_robot_manager.models.trees.parsers.TreeNodeParser.ParserNotFound; import edu.northwestern.cbits.purple_robot_manager.tests.RobotTestCase; public class MatlabTreeModelTestCase extends RobotTestCase { private static final String MODEL_URI = "file:///android_asset/test_data/matlab-model.json"; public MatlabTreeModelTestCase(Context context, int priority) { super(context, priority); } public void test() { if (this.isSelected(this._context) == false) return; try { ArrayList<String> lines = new ArrayList<>(); StringBuilder sb = new StringBuilder(); InputStream file = this._context.getAssets().open("test_data/matlab-tree.txt"); BufferedReader in = new BufferedReader(new InputStreamReader(file, "UTF-8")); String line = null; while ((line = in.readLine()) != null) { sb.append(line + "\n"); lines.add(line); } in.close(); TreeNode node = TreeNodeParser.parseString(sb.toString()); HashMap<String, Object> world = new HashMap<>(); // Outputs class at line 508. world.put("x10", -1.0); world.put("x91", -1.0); world.put("x41", -2.0); world.put("x6", -1.0); Map<String, Object> prediction = node.fetchPrediction(world); Assert.assertEquals("MATLAB1", "3", prediction.get(LeafNode.PREDICTION)); world.put("x6", 0.0); world.put("x30", -1.0); prediction = node.fetchPrediction(world); Assert.assertEquals("MATLAB2", "6", prediction.get(LeafNode.PREDICTION)); Assert.assertEquals("MATLAB3", " 1 if x10<-0.464422 then node 2 elseif x10>=-0.464422 then node 3 else 8", lines.get(0)); Assert.assertEquals("MATLAB4", "1815 class = 1", lines.get(1814)); Assert.assertEquals("MATLAB5", "1455 if x18<0.26036 then node 1560 elseif x18>=0.26036 then node 1561 else 4", lines.get(1454)); Assert.assertEquals("MATLAB6", " 894 if x100<1.42992 then node 1092 elseif x100>=1.42992 then node 1093 else 6", lines.get(893)); Assert.assertEquals("MATLAB7", " 3 if x3<0.392665 then node 6 elseif x3>=0.392665 then node 7 else 8", lines.get(2)); Assert.assertEquals("MATLAB8", " 9 if x108<-0.0457584 then node 18 elseif x108>=-0.0457584 then node 19 else 6", lines.get(8)); world.put("x91", 0.0); world.put("x98", 0.0); world.put("x6", -2.0); world.put("x103", -1.0); prediction = node.fetchPrediction(world); Assert.assertEquals("MATLAB9", "2", prediction.get(LeafNode.PREDICTION)); world.clear(); try { prediction = node.fetchPrediction(world); // Should throw exception before getting here... Assert.fail("MATLAB100"); } catch (TreeNodeException e) { } } catch (ParserNotFound e) { e.printStackTrace(); Assert.fail("MATLAB100"); } catch (TreeNodeException e) { e.printStackTrace(); Assert.fail("MATLAB101"); } catch (IOException e) { e.printStackTrace(); Assert.fail("MATLAB102"); } ModelManager models = ModelManager.getInstance(this._context); models.addModel(MatlabTreeModelTestCase.MODEL_URI); try { Thread.sleep(1000); } catch (InterruptedException e) { } Assert.assertNotNull("MATLAB200", models.fetchModelByName(this._context, MatlabTreeModelTestCase.MODEL_URI)); Assert.assertNull("MATLAB201", models.fetchModelByTitle(this._context, MatlabTreeModelTestCase.MODEL_URI)); Assert.assertNotNull("MATLAB202", models.fetchModelByTitle(this._context, "Matlab Tree Model Test")); HashMap<String, Object> world = new HashMap<>(); // Outputs class at line 508. world.put("x10", -1.0); world.put("x91", -1.0); world.put("x41", -2.0); world.put("x6", -1.0); Model matlab = models.fetchModelByTitle(this._context, "Matlab Tree Model Test"); matlab.predict(this._context, world); try { Thread.sleep(2000); } catch (InterruptedException e) { } Assert.assertEquals("MATLAB203", "3", matlab.latestPrediction(this._context).get(LeafNode.PREDICTION)); models.deleteModel(MatlabTreeModelTestCase.MODEL_URI); } public int estimatedMinutes() { return 1; } public String name(Context context) { return context.getString(R.string.name_matlab_tree_model_test); } }